Electron Density Prediction

using the e3nn repository

tutorial by: Joshua A. Rackers

code by:

DOI

@misc{mario_geiger_2019_3348277,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Wouter Boomsma and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Jes Frellsen and
                  Benjamin K. Miller and
                  Josh Rackers},
  title        = {e3nn/e3nn: Point cloud support},
  month        = jul,
  year         = 2019,
  doi          = {10.5281/zenodo.3348277},
  url          = {https://doi.org/10.5281/zenodo.3348277}
}

Using an E3NN network to predict electron densities

In this tutorial we show how an E3NN network can be used to predict electron densities. One reason this might be a good idea is that electron densities can be represented in a spherical harmonic basis on atom centers. This fits naturally with the E3NN framework.

In [12]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pickle
import torch
import random
from functools import partial

from e3nn.kernel import Kernel
from e3nn.point.operations import Convolution
from e3nn.non_linearities import GatedBlock
from e3nn.non_linearities import rescaled_act
from e3nn.non_linearities.rescaled_act import relu, sigmoid
from e3nn.radial import CosineBasisModel
from e3nn.radial import GaussianRadialModel

torch.set_default_dtype(torch.float64)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

First we load the data. I have saved this in a pickle. For this particular example the dataset is ~6000 water dimer structures with density represented in "a2" density fitting basis set.

Oxygen: 8s, 4p, 4d

Hydrogen: 4s, 1p, 1d

In [2]:
## load density data
picklename = "./density_data/dimer_data.pckl"
with open(picklename, 'rb') as f:
    dataset_coeffs, dataset_onehot, dataset_geom, dataset_typemap, Rs_out_list, coeff_by_type = pickle.load(f)

Step 1: Define a model

The thing that makes this task tricky is predicting different numbers of spherical harmonics on each atomic center. To address this problem, we introduce a class, Mixer, to handle this.

Our network consists of 3 layers.

  1. Input layer: Geometry and one-hot atom encoding
  2. Middle layer: Rs = [(2,0),(1,1),(1,2)] (Up to L=2) Convolution + GatedBlock nonlinearity
  3. Final layer: 2 types of convolution, 1 for oxygen and 1 for hydrogen
In [3]:
## define model

class Mixer(torch.nn.Module):
    def __init__(self, Op, Rs_in_s, Rs_out):
        super().__init__()
        self.ops = torch.nn.ModuleList([
            Op(Rs_in, Rs_out)
            for Rs_in in Rs_in_s
        ])

    def forward(self, *args, n_norm=1):
        # It simply sums the different outputs
        y = 0
        for m, x in zip(self.ops, args):
            y += m(*x, n_norm=n_norm)
        return y


class Network(torch.nn.Module):
    def __init__(self, Rs_in, Rs_out_list, max_radius=3.0, number_of_basis=3, radial_layers=3, basistype="Gaussian"):
        super().__init__()

        #sp = rescaled_act.Softplus(beta=5)
        #sp = rescaled_act.ShiftedSoftplus(beta=5)
        sp = torch.nn.Tanh()

        # the [0] is just to get first_layer in stripped form.
        # will not work for Rs_in with more than L=0
        first_layer = Rs_in[0]
        last_shared_layer = (2,1,1)

        representations = [first_layer, last_shared_layer]
        representations = [[(mul, l) for l, mul in enumerate(rs)] for rs in representations]

        if (basistype == 'Gaussian'):
            rad_basis = GaussianRadialModel
        elif (basistype == 'Cosine'):
            rad_basis = CosineBasisModel
        else:
            print ("Only Gaussian and Cosine Radial basis are currently supported")

        RadialModel = partial(rad_basis, max_radius=max_radius,
                              number_of_basis=number_of_basis, h=100,
                              L=radial_layers, act=sp)

        K = partial(Kernel, RadialModel=RadialModel)
        C = partial(Convolution, K)
        M = partial(Mixer, C)  # wrap C to accept many input types

        def make_layer(Rs_in, Rs_out):
            act = GatedBlock(Rs_out, sp, sigmoid)
            conv = Convolution(K, Rs_in, act.Rs_in)
            return torch.nn.ModuleList([conv, act])

        self.layers = torch.nn.ModuleList([
            make_layer(Rs_layer_in,Rs_layer_out)
            for Rs_layer_in, Rs_layer_out in zip(representations[:-1], representations[1:])
        ])

        ## set up the split final layer
        m = []
        for rs in Rs_out_list:
            m.append(M([representations[-1], representations[-1]], rs))
        
        # final layer is indexed in order of atom type
        self.final_layer = torch.nn.ModuleList([
            m[i] for i in range(len(m))
        ])

    def forward(self, input, geometry, atom_type_map):
        output = input
        batch, N, _ = geometry.shape

        for conv, act in self.layers:
            output = conv(output, geometry, n_norm=N)
            output = act(output)

        ## split final layer
        geometry_list = []
        feature_list = []
        for i, item in enumerate(atom_type_map):
            geometry_list.append(geometry[0][item])
            feature_list.append(output[0][item])

        ## this is assuming that there are only two atom types!
        ## it should work, though for any arbitrary order of O and H in xyzfile!
        featuresO = feature_list[0].unsqueeze(0)
        featuresH = feature_list[1].unsqueeze(0)
        geometryO = geometry_list[0].unsqueeze(0)
        geometryH = geometry_list[1].unsqueeze(0)
        
        final_layer_output = []
        for i, layer in enumerate(self.final_layer):
            if (i == 0):
                final = layer((featuresO, geometryO, geometryO), (featuresH, geometryH, geometryO), n_norm = N)
            if (i == 1):
                final = layer((featuresO, geometryO, geometryH), (featuresH, geometryH, geometryH), n_norm = N) 
            final_layer_output.append(final)

        # return list of outputO and outputH
        output = final_layer_output

        return output

Step 2: Initilize the network

Let's initialize a rough model. Here's a brief description of the parameters:

  • max_radius: Distance of the center of the farthest radial function from convolution center
  • num_basis: Number of radial basis functions to use; more=finer grain detail
  • radial_layers: Number of layers in radial basis network (number of nonlinearity operations)
  • basistype: What type of functions to use for radial basis; default is Gaussians

We pass these parameters in as a dictionary so that we can save them for later use if we want to save the model.

Then we send the model to the GPU

The output shows us a helpful schematic of what kinds of operations our network is going to use.

In [4]:
## set arguments to network
maxradius = 3.0
numbasis = 20
radiallayers = 3
radialbasis = "Gaussian"
## set Rs_in based on onehot vector
Rs_in = [(len(dataset_typemap[0]),0)]

print("Rs_in:",Rs_in)
print("\nOxygen Rs_out:",Rs_out_list[0])
print("Hydrogen Rs_out:",Rs_out_list[1])

mydict = {"Rs_in":Rs_in, "Rs_out_list":(Rs_out_list), "max_radius":maxradius,
            "number_of_basis":numbasis, "radial_layers":radiallayers, 
            "basistype":radialbasis}

net = Network(**mydict)

#net.to(device)
Rs_in: [(2, 0)]

Oxygen Rs_out: [(8, 0), (4, 1), (4, 2)]
Hydrogen Rs_out: [(4, 0), (1, 1), (1, 2)]

Step 3: Set up training

From here, training the model looks virtually identical to any other training one might do with a typical neural network in pytorch. In this case we are going to use the Adam optimizer and minibatches.

In [5]:
## set up training

net.train()

optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
optimizer.zero_grad()
loss_fn = torch.nn.modules.loss.MSELoss()

max_steps = 2000
minibatch_size = 16

print (device)
cpu

Step 4: Train and evaluate test error

In [6]:
loss_minibatch = 0
for step in range(max_steps):
    i = random.randint(0, len(dataset_geom) - 3001)

    onehot = dataset_onehot[i]
    points = dataset_geom[i]
    atom_type_map = dataset_typemap[i]
    coeffs = dataset_coeffs[i]

    outputO, outputH = net(onehot.to(device),points.to(device),atom_type_map)
    outputO = torch.flatten(outputO)
    outputH = torch.flatten(outputH)
    output = torch.cat((outputO,outputH),0).view(1,1,-1)

    loss = loss_fn(output, coeffs)
    step_loss = loss.item()
    loss.backward()
    loss_minibatch += step_loss

    if (step+1)%minibatch_size == 0:
        optimizer.step()
        optimizer.zero_grad()
        loss_minibatch = 0

    if step % 100 == 0:
        print('\nStep {0}, Loss {1}'.format(step, step_loss))
        j = random.randint(3000, len(dataset_geom) - 1)

        onehot = dataset_onehot[j]
        points = dataset_geom[j]*3
        atom_type_map = dataset_typemap[j]
        coeffs = dataset_coeffs[j]

        outputO, outputH = net(onehot.to(device),points.to(device),atom_type_map)
        outputO = torch.flatten(outputO)
        outputH = torch.flatten(outputH)
        output = torch.cat((outputO,outputH),0).view(1,1,-1)

        loss = loss_fn(output.to(device), coeffs.to(device))
        print('\nTest Loss {0}'.format(loss.item()))
Step 0, Loss 0.1373628434784584

Test Loss 0.15495256346889363

Step 100, Loss 0.01203879998819049

Test Loss 0.09017569631856942

Step 200, Loss 0.005279411316785202

Test Loss 0.08689678552297997

Step 300, Loss 0.0031228677403023256

Test Loss 0.07401863503650252

Step 400, Loss 0.0024628998283588326

Test Loss 0.08607386924969587

Step 500, Loss 0.0007832273894890593

Test Loss 0.08013707338360407

Step 600, Loss 0.0008213186782783924

Test Loss 0.07821636209524374

Step 700, Loss 0.00030758708077491373

Test Loss 0.07826866377875111

Step 800, Loss 0.00016716070793977483

Test Loss 0.07466221299117239

Step 900, Loss 7.768865415319805e-05

Test Loss 0.07556705428450777

Step 1000, Loss 0.00021590383188972105

Test Loss 0.0760212542694313

Step 1100, Loss 0.00010018748614177089

Test Loss 0.07490439879119634

Step 1200, Loss 8.623894614598871e-05

Test Loss 0.07747538886646646

Step 1300, Loss 3.970026921748104e-05

Test Loss 0.07503508416307643

Step 1400, Loss 0.00010672196125032913

Test Loss 0.07638448602465367

Step 1500, Loss 2.6953220177630567e-05

Test Loss 0.07557312784346104

Step 1600, Loss 3.7505772993307385e-05

Test Loss 0.07537029362064365

Step 1700, Loss 3.329820470791767e-05

Test Loss 0.07666463182309799

Step 1800, Loss 4.93107589772935e-05

Test Loss 0.07482436074198115

Step 1900, Loss 6.127329517494144e-05

Test Loss 0.07450301653071653

Step 5: Sanity check

Let's check to see if the number of electrons is in the ballpark.

In [7]:
from density_analysis_utils import *

testnumelectrons(net,device,2,"./density_data/a2.gbs",dataset_onehot,dataset_geom,dataset_typemap,coeff_by_type)
Now testing number of electrons on 2 randomly selected dimers

Number of electrons for structure 4570:
   True: 20.014424436015563
     ML: 20.341280613187468

Number of electrons for structure 3055:
   True: 20.014015629367364
     ML: 20.711176789341163

Step 5: Plot predicted vs. true density

To do this we need three components for each function:

  1. The exponent alpha
  2. The normalization constant
  3. The learned coefficient

First we need to do some data wrangling. The basis set we're using has an annoying property that it has 'SP' functions. This mean one entry that specifies an S function and P function simultaneously.

Last we set up our radial function with the above values.

In [8]:
from density_analysis_utils import *
from e3nn.rs import dim, mul_dim

## define Gaussian Type Orbital basis functions
basis = lambda r, alpha, norm : norm * torch.exp(- alpha * r.unsqueeze(-1) **2)

## get exponent alphas
alphaO, alphaH = get_exponents('./density_data/a2.gbs')

## get normalization constants
normO, normH = parse_whole_normfile('./density_data/a2_norm.dat')
normO = torch.FloatTensor(normO)
normH = torch.FloatTensor(normH)

## get spherical harmonic normalization constants
Rs_out_O = Rs_out_list[0]
Rs_out_H = Rs_out_list[1]
sph_normsO, sph_normsH = get_spherical_harmonic_norms(Rs_out_O,Rs_out_H)

basis_on_r_O = partial(basis, alpha=alphaO, norm=normO)
basis_on_r_H = partial(basis, alpha=alphaH, norm=normH)

assert mul_dim(Rs_out_O) == normO.shape[0]
assert mul_dim(Rs_out_H) == normH.shape[0]
In [9]:
# pick a random structure to test

dimer_num = 4321
onehot = dataset_onehot[dimer_num]
points = dataset_geom[dimer_num]
atom_type_map = dataset_typemap[dimer_num]
outputO, outputH = net(onehot.to(device),points.to(device),atom_type_map)

outputO = outputO.data.cpu().numpy()
outputH = outputH.data.cpu().numpy()
In [18]:
from spherical import plot_data_on_grid
import e3nn.o3 as o3

## get the functions
f_list = []
# loop over types
for i, type in enumerate(atom_type_map):
    # loop over atoms
    for count, atom in enumerate(type):
        tot_f = 0
        center = points.data.squeeze().numpy()[atom]
        # oxygens
        if i == 0:
            #vsf = VisualizeSphericalFunction(basis_on_r_O, Rs_out_O, o3.spherical_harmonics_xyz)
            r, f = plot_data_on_grid(5.0, basis_on_r_O, Rs_out_O,
                                         n=20, center=center)
            for j, val in enumerate(outputO.squeeze()[count]):
                c = val
                norm = sph_normsO[j]
                # sum up contributions from every basis function
                tot_f += c*f[:,j]/norm
        # hydrogens
        if i == 1:
            #vsf = VisualizeSphericalFunction(basis_on_r_H, Rs_out_H, o3.spherical_harmonics_xyz)
            r, f = plot_data_on_grid(5.0, basis_on_r_H, Rs_out_H,
                                         n=20, center=center)
            for j, val in enumerate(outputH.squeeze()[count]):
                c = val
                norm = sph_normsH[j]
                # sum up contributions from every basis function
                tot_f += c*f[:,j]/norm

        f_list.append(tot_f)

all_atom_f = sum(f_list)
print(all_atom_f.max())
tensor(34.2586)
In [19]:
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go

plot_max = float(all_atom_f.max())

fig = go.Figure(data=go.Volume(
    x=r[:,0],
    y=r[:,1],
    z=r[:,2],
    #value=c * f[:, i],
    value=all_atom_f,
    isomin=-0.005*plot_max,
    isomax=0.005*plot_max,
    #isomin=-0.03,
    #isomax=0.03,
    opacity=0.3, # needs to be small to see through all surfaces
    opacityscale="uniform",
    surface_count=50, # needs to be a large number for good volume rendering
    colorscale='RdBu'))
    
xs = points.data.squeeze().numpy()[:,0]
ys = points.data.squeeze().numpy()[:,1]
zs = points.data.squeeze().numpy()[:,2]
fig.add_scatter3d(x=xs,y=ys,z=zs,mode='markers',marker=dict(size=12,color='Black',opacity=1.0))

fig.show()
In [ ]:
 
In [ ]:

In [ ]: